from data.data_loader import create_data_loaders
from model.trainer_towers import TowersTrainer
from model.trainer_dnn import DNNTrainer
from config import Config

def main(trainer_cls=None, build_index=False):
    """Train a recommendation model and optionally build a Faiss item index.

    Args:
        trainer_cls: Trainer class instantiated as ``trainer_cls(config)`` and
            then trained. Defaults to ``TowersTrainer``; pass ``DNNTrainer`` to
            train the DNN model instead.
        build_index: If True, build a Faiss index from
            ``trainer.model.item_tower`` after training. Defaults to False,
            matching the original script where this step was commented out.

    Returns:
        The trainer instance after ``train`` has been called.
    """
    if trainer_cls is None:
        # Two-tower model is the default, as in the original script.
        trainer_cls = TowersTrainer

    config = Config()

    # Build the train/test data loaders from the config.
    train_loader, test_loader = create_data_loaders(config)

    # Train the selected model.
    trainer = trainer_cls(config)
    trainer.train(train_loader, test_loader)

    if build_index:
        # Imported lazily so the serving module is only required when indexing.
        from serving.faiss_indexer import FaissIndexer

        # NOTE(review): assumes trainer.model exposes `item_tower` (the original
        # script used this with TowersTrainer); confirm before using with
        # other trainer classes.
        indexer = FaissIndexer(config)
        indexer.build_index(trainer.model.item_tower)

    return trainer


if __name__ == "__main__":
    main()